Code and visualizations to test, debug, and evaluate the Mask R-CNN model.
import os
import sys
import random
import math
import re
import time
import numpy as np
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pickle
import cv2
# Root directory of the project
ROOT_DIR = os.path.abspath("../../")
import skimage
# Import Mask RCNN
sys.path.append(ROOT_DIR) # To find local version of the library
from mrcnn import utils
from mrcnn import visualize
from mrcnn.config import Config
from mrcnn.visualize import display_images
import mrcnn.model as modellib
from mrcnn.model import log
import pandas as pd
%matplotlib inline
# Root directory of the project.
ROOT_DIR = os.path.abspath("../../")
# Directory where training logs and checkpoints are written.
# Build the path from separate components: a component that begins with a
# path separator (as "\\samples\\cars\\logs" did) makes os.path.join discard
# ROOT_DIR on Windows and restart from the drive root.
MODEL_DIR = os.path.join(ROOT_DIR, "samples", "cars", "logs")
# Path to the trained car weights (a training-run log directory).
# Pretrained Mask R-CNN weights can be downloaded from the Releases page:
# https://github.com/matterport/Mask_RCNN/releases
cars_WEIGHTS_PATH = "D:\\Master TAID\\Anul2\\MLAV\\Car-Detection-Mask-R-CNN\\samples\\cars\\logs\\car20200125T0225" # TODO: update this path
# Root of the BoxCars116k dataset (machine-specific; adjust per setup).
BOXCARS_DATASET_ROOT = "D:\\Master TAID\\Anul2\\MLAV\\Car-Detection-Mask-R-CNN\\DataSet\\BoxCars116k"
#%%
# Locations of the individual BoxCars files under the dataset root.
BOXCARS_IMAGES_ROOT = os.path.join(BOXCARS_DATASET_ROOT, "images")
BOXCARS_DATASET = os.path.join(BOXCARS_DATASET_ROOT, "dataset.pkl")
BOXCARS_ATLAS = os.path.join(BOXCARS_DATASET_ROOT, "atlas.pkl")
BOXCARS_CLASSIFICATION_SPLITS = os.path.join(BOXCARS_DATASET_ROOT, "classification_splits.pkl")
############################################################
# Configurations
############################################################
class CarsConfig(Config):
    """Mask R-CNN settings for training on the BoxCars car dataset.

    Only the values that differ from the base ``Config`` are overridden here;
    everything else is inherited unchanged.
    """

    # Recognizable configuration name (the "car..." prefix seen in the
    # training-run log directory names).
    NAME = "car"

    # Total class count: the implicit background class plus "car".
    NUM_CLASSES = 1 + 1

    # One image per GPU step -- sized for a 6 GB GPU. Increase this on a
    # GPU with more memory.
    IMAGES_PER_GPU = 1

    # How many training steps make up one epoch.
    STEPS_PER_EPOCH = 100

    # Detections scoring below 0.9 (90 %) are discarded.
    DETECTION_MIN_CONFIDENCE = 0.9
class CarsDataset(utils.Dataset):
    """BoxCars116k adapter for the Mask R-CNN ``utils.Dataset`` interface.

    Images are decoded from the BoxCars atlas pickle. Each instance's 2-D
    bounding box (``2DBB``) is rasterised into one rectangular mask with
    class id 1 ("car").
    """

    def initialize_data(self, part):
        """Load the BoxCars metadata and register every image of one split.

        part: "train", "validation" or "test".
        Call ``prepare()`` afterwards, before using the dataset.
        """
        self.X = {}
        self.Y = {}
        self.part = part
        self.add_class("car", 1, "car")
        # Use a separate loop variable: reusing ``part`` here made the
        # asserts below validate "test" instead of the requested split.
        for split_name in ("train", "validation", "test"):
            self.X[split_name] = None
            self.Y[split_name] = None  # for labels as array of 0-1 flags
        # Validate before the (slow) pickle loads.
        assert part in self.X, "unknown part -- use: train, validation, test"
        assert self.X[part] is None, "part %s was already initialized" % part

        self.dataset = self.load_cache(BOXCARS_DATASET)
        self.atlas = self.load_cache(BOXCARS_ATLAS)
        self.split = self.load_cache(BOXCARS_CLASSIFICATION_SPLITS)['hard']
        assert self.split is not None, "load classification split first"
        self.nr_of_classes = len(self.split["types_mapping"])
        # ``df`` used to be a second pd.read_pickle() of the very same file;
        # alias the already-loaded metadata instead of reading it twice.
        self.df = self.dataset

        samples = self.dataset["samples"]
        # One (vehicle_id, instance_id) row per image of this split.
        pairs = []
        for vehicle_id, _label in self.split[part]:
            num_instances = len(samples[vehicle_id]["instances"])
            pairs.extend((vehicle_id, instance_id) for instance_id in range(num_instances))
        self.X[part] = np.asarray(pairs, dtype=int)

        for vehicle_id, instance_id in self.X[part]:
            # Decode the image only to learn its size.
            height, width = self.get_image_by_id(vehicle_id, instance_id).shape[:2]
            rel_path = samples[vehicle_id]["instances"][instance_id]["path"]
            _, filename = os.path.split(rel_path)
            self.add_image(
                "car",
                # NOTE(review): file names could repeat across vehicle
                # directories; confirm uniqueness if these ids are relied on.
                image_id=filename,  # use file name as a unique image id
                path=os.path.join(BOXCARS_IMAGES_ROOT, rel_path),
                width=width, height=height,
                polygons=1)

    def load_cache(self, path, encoding="latin-1", fix_imports=True):
        """Unpickle and return the object stored at ``path``.

        ``encoding``/``fix_imports`` are forwarded to ``pickle.load`` (they
        matter for pickles written by Python 2). Only use on trusted files:
        unpickling can execute arbitrary code.
        """
        with open(path, "rb") as f:
            # Forward the caller's fix_imports (it was previously ignored).
            return pickle.load(f, encoding=encoding, fix_imports=fix_imports)

    def get_image(self, image_id):
        """
        returns decoded image from atlas in RGB channel order
        (``image_id`` is the row index into this split's (vehicle, instance) table)
        """
        vehicle_id, instance_id = self.X[self.part][image_id]
        return self.get_image_by_id(vehicle_id, instance_id)

    def get_image_by_id(self, vehicle_id, instance_id):
        """
        returns decoded image from atlas in RGB channel order
        """
        return cv2.cvtColor(cv2.imdecode(self.atlas[vehicle_id][instance_id], 1), cv2.COLOR_BGR2RGB)

    def load_mask(self, image_id):
        """Return ``(masks, class_ids)`` for one image.

        masks: bool array [height, width, 1] with the 2-D bounding box filled.
        class_ids: int32 array [1] of ones (the "car" class).
        Images from other sources are delegated to the base class.
        """
        image_info = self.image_info[image_id]
        if image_info["source"] != "car":
            return super().load_mask(image_id)
        height, width = self.get_image(image_id).shape[:2]
        mask = np.zeros([height, width, 1], dtype=bool)
        (row0, col0), (row1, col1) = self.getMask2D(image_id)
        # getMask2D returns an exclusive far corner (origin + size), so fill
        # the half-open range [row0, row1) x [col0, col1). The previous
        # skimage.draw.rectangle call treated that corner as inclusive and
        # drew one extra row and column. Slicing clips at the far edge;
        # clamp the near edge at 0 because negative indices would wrap.
        mask[max(row0, 0):max(row1, 0), max(col0, 0):max(col1, 0), :] = True
        # Plain ``bool`` dtype: the ``np.bool`` alias was removed in NumPy 1.24.
        return mask, np.ones([mask.shape[-1]], dtype=np.int32)

    def getMask3D(self, index):
        """Return (x coords, y coords, all points) of an image's ``3DBB`` as int32."""
        vehicle_id, instance_id = self.X[self.part][index]
        points = np.array(self.df['samples'][vehicle_id]['instances'][instance_id]['3DBB'], np.int32)
        X = points[:, 0]
        Y = points[:, 1]
        return X, Y, np.array(points)

    def getMask2D(self, index):
        """Return ``((row0, col0), (row1, col1))`` for an image's 2-D box.

        The last two ``2DBB`` values are added to the first two, i.e. treated
        as width/height (presumably BoxCars stores (x, y, w, h) -- confirm),
        so ``(row1, col1)`` is one past the last covered pixel.
        """
        vehicle_id, instance_id = self.X[self.part][index]
        x1, y1, x2, y2 = self.df['samples'][vehicle_id]['instances'][instance_id]['2DBB']
        return (int(y1), int(x1)), (int(y2 + y1), int(x2 + x1))

    def image_reference(self, image_id):
        """Return the path of the image."""
        info = self.image_info[image_id]
        if info["source"] == "car":
            return info["path"]
        # Previously the base-class result was computed but not returned.
        return super().image_reference(image_id)
def train(model, epochs=30, layers='heads'):
    """Train the model on the BoxCars train split, validating on "validation".

    model: a Mask R-CNN model (presumably created in "training" mode).
    epochs: number of epochs passed to ``model.train`` (default 30).
    layers: which layers to train (default 'heads').
    """
    # Training dataset.
    dataset_train = CarsDataset()
    dataset_train.initialize_data("train")
    dataset_train.prepare()
    # Validation dataset
    dataset_val = CarsDataset()
    dataset_val.initialize_data("validation")
    dataset_val.prepare()
    # *** This training schedule is an example. Update to your needs ***
    # Starting from pretrained weights, training only the heads is usually
    # enough as a first pass.
    print("Training network {}".format(layers))
    # Use the model's own configuration rather than whatever object the
    # global ``config`` happens to be bound to at call time.
    model.train(dataset_train, dataset_val,
                learning_rate=model.config.LEARNING_RATE,
                epochs=epochs,
                layers=layers)
def color_splash(image, mask):
    """Return ``image`` in grayscale except where any instance mask is set.

    image: RGB image [height, width, 3]
    mask: instance segmentation mask [height, width, instance count]
    Returns a uint8 image: original colours inside the union of all masks,
    a 3-channel grayscale copy everywhere else.
    """
    # Grayscale version replicated across 3 channels, rescaled to 0-255.
    gray = skimage.color.gray2rgb(skimage.color.rgb2gray(image)) * 255
    if mask.shape[-1] <= 0:
        # No instances at all: the whole frame is grayscale.
        return gray.astype(np.uint8)
    # All instances are treated as one: collapse them into a single layer.
    foreground = np.sum(mask, -1, keepdims=True) >= 1
    return np.where(foreground, image, gray).astype(np.uint8)
def detect_and_color_splash(model, image_path=None, video_path=None):
    """Run detection and save a colour-splash rendering of the result.

    model: Mask R-CNN model in inference mode.
    image_path: image to process; saved as splash_<timestamp>.png.
    video_path: video to process frame by frame; saved as
        splash_<timestamp>.avi (MJPG). Used only when image_path is falsy.
    Returns the output file name (previously None; callers that ignore the
    return value are unaffected).
    """
    import datetime  # was used below but never imported

    assert image_path or video_path
    # Image or video?
    if image_path:
        # Run model detection and generate the color splash effect.
        # (Previously read the undefined global ``args.image``.)
        print("Running on {}".format(image_path))
        # Read image
        image = skimage.io.imread(image_path)
        # The model receives 3-channel RGB elsewhere in this file: expand
        # grayscale and drop an alpha channel if present.
        if image.ndim != 3:
            image = skimage.color.gray2rgb(image)
        if image.shape[-1] == 4:
            image = image[..., :3]
        # Detect objects
        r = model.detect([image], verbose=1)[0]
        # Color splash
        splash = color_splash(image, r['masks'])
        # Save output
        file_name = "splash_{:%Y%m%dT%H%M%S}.png".format(datetime.datetime.now())
        skimage.io.imsave(file_name, splash)
    elif video_path:
        # Video capture
        vcapture = cv2.VideoCapture(video_path)
        width = int(vcapture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vcapture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = vcapture.get(cv2.CAP_PROP_FPS)
        # Define codec and create video writer
        file_name = "splash_{:%Y%m%dT%H%M%S}.avi".format(datetime.datetime.now())
        vwriter = cv2.VideoWriter(file_name,
                                  cv2.VideoWriter_fourcc(*'MJPG'),
                                  fps, (width, height))
        try:
            count = 0
            success = True
            while success:
                print("frame: ", count)
                # Read next image
                success, image = vcapture.read()
                if success:
                    # OpenCV returns images as BGR, convert to RGB
                    image = image[..., ::-1]
                    # Detect objects
                    r = model.detect([image], verbose=0)[0]
                    # Color splash
                    splash = color_splash(image, r['masks'])
                    # RGB -> BGR to save image to video
                    splash = splash[..., ::-1]
                    # Add image to video writer
                    vwriter.write(splash)
                    count += 1
        finally:
            # Release both handles even if detection fails mid-video.
            vwriter.release()
            vcapture.release()
    print("Saved to ", file_name)
    return file_name
config = CarsConfig()
# Directory holding the BoxCars116k images. Built from separate components:
# the former "DataSet\\BoxCars116k\\images" literal was only a valid
# sub-path on Windows.
CarDir = os.path.join(ROOT_DIR, "DataSet", "BoxCars116k", "images")


# Override the training configurations with a few
# changes for inferencing.
class InferenceConfig(config.__class__):
    # Run detection on one image at a time
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1


config = InferenceConfig()
config.display()
# Device to load the neural network on.
# Useful if you're training a model on the same
# machine, in which case use CPU and leave the
# GPU for training.
DEVICE = "/cpu:0"  # /cpu:0 or /gpu:0
# Inspect the model in training or inference modes
# values: 'inference' or 'training'
# TODO: code for 'training' test mode not ready yet
TEST_MODE = "inference"
def get_ax(rows=1, cols=1, size=16):
    """Create a ``rows`` x ``cols`` grid of Matplotlib Axes and return it.

    Intended as the one place that controls plot sizes for the notebook's
    visualizations: the figure is ``size * cols`` wide by ``size * rows``
    tall. Adjust ``size`` to render images larger or smaller.
    """
    figure_size = (size * cols, size * rows)
    _figure, axes = plt.subplots(rows, cols, figsize=figure_size)
    return axes
# Load validation dataset
dataset = CarsDataset()
dataset.initialize_data('validation')
# Must call before using the dataset
dataset.prepare()
print("Images: {}\nClasses: {}".format(len(dataset.image_ids), dataset.class_names))
# Create model in inference mode, built under the DEVICE selected above.
with tf.device(DEVICE):
    model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR,
                              config=config)
# Path to the trained car weights file.
# Pretrained weights are available from the Releases page:
# https://github.com/matterport/Mask_RCNN/releases
# weights_path = "/path/to/mask_rcnn_balloon.h5"
# Here: a checkpoint from an earlier training run of this project
# (machine-specific path; update for your setup).
weights_path ='D:\\Master TAID\\Anul2\\MLAV\\Car-Detection-Mask-R-CNN\\samples\\cars\\logs\\car20200125T0225\\mask_rcnn_car_0010.h5'
# Load weights
print("Loading weights ", weights_path)
model.load_weights(weights_path, by_name=True)
# Pick a random validation image and load it with its ground truth
# (mini masks disabled, so gt_mask is not the compressed mini-mask form).
image_id = random.choice(dataset.image_ids)
image, image_meta, gt_class_id, gt_bbox, gt_mask =\
    modellib.load_image_gt(dataset, config, image_id, use_mini_mask=False)
info = dataset.image_info[image_id]
print("image ID: {}.{} ({}) {}".format(info["source"], info["id"], image_id,
                                       dataset.image_reference(image_id)))
# image=io.imread('D:\\Master TAID\\Anul2\\MLAV\\Car-Detection-Mask-R-CNN\\samples\\cars\\beach.jfif')
# Run object detection
results = model.detect([image], verbose=1)
# Display results
ax = get_ax(1)
r = results[0]
visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
                            dataset.class_names, r['scores'], ax=ax,
                            title="Predictions")
# Ground truth for comparison with the predictions drawn above.
log("gt_class_id", gt_class_id)
log("gt_bbox", gt_bbox)
log("gt_mask", gt_mask)
This is for illustration. You can call balloon.py with the splash option to get better images without the black padding.
# Colour-splash the detections from the previous cell (`image` and `r`
# come from the inference run above) and show the result.
splash = color_splash(image, r['masks'])
display_images([splash], cols=1)
The Region Proposal Network (RPN) runs a lightweight binary classifier on many boxes (anchors) spread over the image and returns object/no-object scores. Anchors with a high objectness score (positive anchors) are passed to the second stage to be classified.
Often, even positive anchors don't cover objects fully, so the RPN also regresses a refinement (a delta in location and size). The refinement is applied to each anchor to shift and resize it slightly, so that it matches the object's boundaries.
The RPN targets are the training values for the RPN. To generate the targets, we start with a grid of anchors that cover the full image at different scales, and then we compute the IoU of the anchors with the ground-truth objects. Positive anchors are those that have an IoU >= 0.7 with any ground-truth object, and negative anchors are those that don't cover any object by more than 0.3 IoU. Anchors in between (i.e. cover an object by IoU >= 0.3 but < 0.7) are considered neutral and excluded from training.
To train the RPN regressor, we also compute the shift and resizing needed to make the anchor cover the ground truth object completely.
# Generate RPN training targets
# target_rpn_match is 1 for positive anchors, -1 for negative anchors
# and 0 for neutral anchors.
target_rpn_match, target_rpn_bbox = modellib.build_rpn_targets(
    image.shape, model.anchors, gt_class_id, gt_bbox, model.config)
log("target_rpn_match", target_rpn_match)
log("target_rpn_bbox", target_rpn_bbox)
# Split the anchor indices (and the anchors themselves) by target label.
positive_anchor_ix = np.where(target_rpn_match[:] == 1)[0]
negative_anchor_ix = np.where(target_rpn_match[:] == -1)[0]
neutral_anchor_ix = np.where(target_rpn_match[:] == 0)[0]
positive_anchors = model.anchors[positive_anchor_ix]
negative_anchors = model.anchors[negative_anchor_ix]
neutral_anchors = model.anchors[neutral_anchor_ix]
log("positive_anchors", positive_anchors)
log("negative_anchors", negative_anchors)
log("neutral anchors", neutral_anchors)
# Apply refinement deltas to positive anchors.
# NOTE(review): taking the first len(positive_anchors) rows assumes
# build_rpn_targets stores the positive anchors' deltas first, in the same
# order as positive_anchor_ix -- confirm. The deltas are presumably stored
# normalized by RPN_BBOX_STD_DEV; multiplying undoes that.
refined_anchors = utils.apply_box_deltas(
    positive_anchors,
    target_rpn_bbox[:positive_anchors.shape[0]] * model.config.RPN_BBOX_STD_DEV)
log("refined_anchors", refined_anchors, )
# Display positive anchors before refinement (dotted) and
# after refinement (solid).
visualize.draw_boxes(image, boxes=positive_anchors, refined_boxes=refined_anchors, ax=get_ax())
Here we run the RPN graph and display its predictions.
# Run RPN sub-graph
pillar = model.keras_model.get_layer("ROI").output # node to start searching from
# TF 1.4 and 1.9 introduce new versions of NMS. Search for all names to support TF 1.3~1.10
# NOTE(review): if none of the names match (e.g. a newer TF), nms_node stays
# None and the run_graph call below will fail.
nms_node = model.ancestor(pillar, "ROI/rpn_non_max_suppression:0")
if nms_node is None:
    nms_node = model.ancestor(pillar, "ROI/rpn_non_max_suppression/NonMaxSuppressionV2:0")
if nms_node is None: #TF 1.9-1.10
    nms_node = model.ancestor(pillar, "ROI/rpn_non_max_suppression/NonMaxSuppressionV3:0")
# Fetch the intermediate RPN tensors for this image in a single run.
rpn = model.run_graph([image], [
    ("rpn_class", model.keras_model.get_layer("rpn_class").output),
    ("pre_nms_anchors", model.ancestor(pillar, "ROI/pre_nms_anchors:0")),
    ("refined_anchors", model.ancestor(pillar, "ROI/refined_anchors:0")),
    ("refined_anchors_clipped", model.ancestor(pillar, "ROI/refined_anchors_clipped:0")),
    ("post_nms_anchor_ix", nms_node),
    ("proposals", model.keras_model.get_layer("ROI").output),
])
# Show top anchors by score (before refinement)
limit = 100
# Column 1 of rpn_class is used as the objectness score; sort descending.
sorted_anchor_ids = np.argsort(rpn['rpn_class'][:,:,1].flatten())[::-1]
visualize.draw_boxes(image, boxes=model.anchors[sorted_anchor_ids[:limit]], ax=get_ax())
# Show top anchors with refinement. Then with clipping to image boundaries
limit = 50
ax = get_ax(1, 2)
# Convert the anchor tensors (presumably in normalized coordinates) to
# pixel coordinates of this image.
pre_nms_anchors = utils.denorm_boxes(rpn["pre_nms_anchors"][0], image.shape[:2])
refined_anchors = utils.denorm_boxes(rpn["refined_anchors"][0], image.shape[:2])
refined_anchors_clipped = utils.denorm_boxes(rpn["refined_anchors_clipped"][0], image.shape[:2])
visualize.draw_boxes(image, boxes=pre_nms_anchors[:limit],
                     refined_boxes=refined_anchors[:limit], ax=ax[0])
visualize.draw_boxes(image, refined_boxes=refined_anchors_clipped[:limit], ax=ax[1])
# Show refined anchors after non-max suppression
limit = 50
ixs = rpn["post_nms_anchor_ix"][:limit]
visualize.draw_boxes(image, refined_boxes=refined_anchors_clipped[ixs], ax=get_ax())
# Show final proposals
# These are the same as the previous step (refined anchors
# after NMS) but with coordinates normalized to [0, 1] range.
limit = 50
# Convert back to image coordinates for display
h, w = config.IMAGE_SHAPE[:2]
proposals = rpn['proposals'][0, :limit] * np.array([h, w, h, w])
visualize.draw_boxes(image, refined_boxes=proposals, ax=get_ax())
This stage takes the region proposals from the RPN and classifies them.
Run the classifier heads on proposals to generate class probabilities and bounding box regressions.
# Get input and output to classifier and mask heads.
mrcnn = model.run_graph([image], [
    ("proposals", model.keras_model.get_layer("ROI").output),
    ("probs", model.keras_model.get_layer("mrcnn_class").output),
    ("deltas", model.keras_model.get_layer("mrcnn_bbox").output),
    ("masks", model.keras_model.get_layer("mrcnn_mask").output),
    ("detections", model.keras_model.get_layer("mrcnn_detection").output),
])
# Get detection class IDs. Trim zero padding.
det_class_ids = mrcnn['detections'][0, :, 4].astype(np.int32)
# Unused detection rows are zero-padded (class 0). When every slot holds a
# real detection there is no padding, and indexing the first zero would
# raise IndexError -- fall back to the full length in that case.
padding_ix = np.where(det_class_ids == 0)[0]
det_count = padding_ix[0] if padding_ix.size else det_class_ids.shape[0]
det_class_ids = det_class_ids[:det_count]
detections = mrcnn['detections'][0, :det_count]
print("{} detections: {}".format(
    det_count, np.array(dataset.class_names)[det_class_ids]))
# Caption each detection with its class name and score (columns 4 and 5).
captions = ["{} {:.3f}".format(dataset.class_names[int(c)], s) if c > 0 else ""
            for c, s in zip(detections[:, 4], detections[:, 5])]
visualize.draw_boxes(
    image,
    refined_boxes=utils.denorm_boxes(detections[:, :4], image.shape[:2]),
    visibilities=[2] * len(detections),
    captions=captions, title="Detections",
    ax=get_ax())
Here we look more closely at how the detections are processed.
# Proposals are in normalized coordinates. Scale them
# to image coordinates.
h, w = config.IMAGE_SHAPE[:2]
proposals = np.around(mrcnn["proposals"][0] * np.array([h, w, h, w])).astype(np.int32)
# Class ID (argmax of the class probabilities), its score and class name
# for every proposal.
roi_class_ids = np.argmax(mrcnn["probs"][0], axis=1)
roi_scores = mrcnn["probs"][0, np.arange(roi_class_ids.shape[0]), roi_class_ids]
roi_class_names = np.array(dataset.class_names)[roi_class_ids]
# Proposals not classified as background (class 0).
roi_positive_ixs = np.where(roi_class_ids > 0)[0]
# How many ROIs vs empty rows?
print("{} Valid proposals out of {}".format(np.sum(np.any(proposals, axis=1)), proposals.shape[0]))
print("{} Positive ROIs".format(len(roi_positive_ixs)))
# Class counts
print(list(zip(*np.unique(roi_class_names, return_counts=True))))
# Display a random sample of proposals.
# Proposals classified as background are dotted, and
# the rest show their class and confidence score.
limit = 200
# Sampled with replacement, so a proposal can appear more than once.
ixs = np.random.randint(0, proposals.shape[0], limit)
captions = ["{} {:.3f}".format(dataset.class_names[c], s) if c > 0 else ""
            for c, s in zip(roi_class_ids[ixs], roi_scores[ixs])]
visualize.draw_boxes(image, boxes=proposals[ixs],
                     visibilities=np.where(roi_class_ids[ixs] > 0, 2, 1),
                     captions=captions, title="ROIs Before Refinement",
                     ax=get_ax())
# Class-specific bounding box shifts: for each proposal, the deltas
# predicted for its own argmax class.
roi_bbox_specific = mrcnn["deltas"][0, np.arange(proposals.shape[0]), roi_class_ids]
log("roi_bbox_specific", roi_bbox_specific)
# Apply bounding box transformations
# Shape: [N, (y1, x1, y2, x2)]
refined_proposals = utils.apply_box_deltas(
    proposals, roi_bbox_specific * config.BBOX_STD_DEV).astype(np.int32)
log("refined_proposals", refined_proposals)
# Show positive proposals
# ids = np.arange(len(roi_positive_ixs))  # Display all
limit = 5
if len(roi_positive_ixs) == 0:
    # np.random.randint(0, 0, ...) raises ValueError; nothing to draw.
    print("No positive ROIs to display")
else:
    # Random sample (with replacement) of the positive ROIs.
    ids = np.random.randint(0, len(roi_positive_ixs), limit)
    captions = ["{} {:.3f}".format(dataset.class_names[c], s) if c > 0 else ""
                for c, s in zip(roi_class_ids[roi_positive_ixs][ids], roi_scores[roi_positive_ixs][ids])]
    visualize.draw_boxes(image, boxes=proposals[roi_positive_ixs][ids],
                         refined_boxes=refined_proposals[roi_positive_ixs][ids],
                         visibilities=np.where(roi_class_ids[roi_positive_ixs][ids] > 0, 1, 0),
                         captions=captions, title="ROIs After Refinement",
                         ax=get_ax())
# Remove boxes classified as background
keep = np.where(roi_class_ids > 0)[0]
print("Keep {} detections:\n{}".format(keep.shape[0], keep))
# Remove low confidence detections
keep = np.intersect1d(keep, np.where(roi_scores >= config.DETECTION_MIN_CONFIDENCE)[0])
print("Remove boxes below {} confidence. Keep {}:\n{}".format(
    config.DETECTION_MIN_CONFIDENCE, keep.shape[0], keep))
# Apply per-class non-max suppression
pre_nms_boxes = refined_proposals[keep]
pre_nms_scores = roi_scores[keep]
pre_nms_class_ids = roi_class_ids[keep]
nms_keep = []
for class_id in np.unique(pre_nms_class_ids):
    # Pick detections of this class
    ixs = np.where(pre_nms_class_ids == class_id)[0]
    # Apply NMS
    class_keep = utils.non_max_suppression(pre_nms_boxes[ixs],
                                           pre_nms_scores[ixs],
                                           config.DETECTION_NMS_THRESHOLD)
    # Map indices: class_keep indexes into ixs, which indexes into keep,
    # so map back to indices into the full proposal list.
    class_keep = keep[ixs[class_keep]]
    nms_keep = np.union1d(nms_keep, class_keep)
    print("{:22}: {} -> {}".format(dataset.class_names[class_id][:20],
                                   keep[ixs], class_keep))
# union1d starting from an empty list may yield floats, hence the cast.
keep = np.intersect1d(keep, nms_keep).astype(np.int32)
print("\nKept after per-class NMS: {}\n{}".format(keep.shape[0], keep))
# Show final detections
ixs = np.arange(len(keep)) # Display all
# ixs = np.random.randint(0, len(keep), 10) # Display random sample
captions = ["{} {:.3f}".format(dataset.class_names[c], s) if c > 0 else ""
            for c, s in zip(roi_class_ids[keep][ixs], roi_scores[keep][ixs])]
visualize.draw_boxes(
    image, boxes=proposals[keep][ixs],
    refined_boxes=refined_proposals[keep][ixs],
    visibilities=np.where(roi_class_ids[keep][ixs] > 0, 1, 0),
    captions=captions, title="Detections after NMS",
    ax=get_ax())
This stage takes the detections (refined bounding boxes and class IDs) from the previous layer and runs the mask head to generate segmentation masks for every instance.
These are the training targets for the mask branch
# Ground-truth masks (the training targets for the mask branch), one panel
# per instance.
display_images(np.transpose(gt_mask, [2, 0, 1]), cmap="Blues")
# Get predictions of mask head
mrcnn = model.run_graph([image], [
    ("detections", model.keras_model.get_layer("mrcnn_detection").output),
    ("masks", model.keras_model.get_layer("mrcnn_mask").output),
])
# Get detection class IDs. Trim zero padding.
det_class_ids = mrcnn['detections'][0, :, 4].astype(np.int32)
# Without any zero padding (every slot used) np.where(...)[0][0] would raise
# IndexError, so fall back to the full length.
padding_ix = np.where(det_class_ids == 0)[0]
det_count = padding_ix[0] if padding_ix.size else det_class_ids.shape[0]
det_class_ids = det_class_ids[:det_count]
print("{} detections: {}".format(
    det_count, np.array(dataset.class_names)[det_class_ids]))
# Masks
det_boxes = utils.denorm_boxes(mrcnn["detections"][0, :, :4], image.shape[:2])
# For each detection, take the mask channel of its predicted class.
det_mask_specific = np.array([mrcnn["masks"][0, i, :, :, c]
                              for i, c in enumerate(det_class_ids)])
# Presumably pastes each predicted mask into its box at full image size.
det_masks = np.array([utils.unmold_mask(m, det_boxes[i], image.shape)
                      for i, m in enumerate(det_mask_specific)])
log("det_mask_specific", det_mask_specific)
log("det_masks", det_masks)
display_images(det_mask_specific[:4] * 255, cmap="Blues", interpolation="none")
display_images(det_masks[:4] * 255, cmap="Blues", interpolation="none")
In some cases it helps to look at the output from different layers and visualize them to catch issues and odd patterns.
# Get activations of a few sample layers
activations = model.run_graph([image], [
    ("input_image", tf.identity(model.keras_model.get_layer("input_image").output)),
    ("res2c_out", model.keras_model.get_layer("res2c_out").output),
    ("res3c_out", model.keras_model.get_layer("res3c_out").output),
    ("res4w_out", model.keras_model.get_layer("res4w_out").output), # only in the deeper ResNet backbone (presumably resnet101) -- confirm
    ("rpn_bbox", model.keras_model.get_layer("rpn_bbox").output),
    ("roi", model.keras_model.get_layer("ROI").output),
])
# Input image (normalized); unmold_image converts it back for display.
_ = plt.imshow(modellib.unmold_image(activations["input_image"][0],config))
# Backbone feature map: first 4 channels of res2c_out, one panel each.
display_images(np.transpose(activations["res2c_out"][0,:,:,:4], [2, 0, 1]), cols=4)